{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "# Solving reinforcement learning tasks using functional evolutionary algorithms\n",
    "\n",
    "The functional implementations of evolutionary algorithms can interact with the object-oriented `Problem` API of EvoTorch. To demonstrate this, we instantiate a `GymNE` problem configured to work on the reinforcement learning task `CartPole-v1`, and we use the functional `pgpe` algorithm to solve it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from evotorch.algorithms.functional import pgpe, pgpe_ask, pgpe_tell\n",
    "from evotorch.neuroevolution import GymNE\n",
    "from datetime import datetime\n",
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2",
   "metadata": {},
   "source": [
    "Below, we instantiate the reinforcement learning problem."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "problem = GymNE(\n",
    "    # The id of the gymnasium task:\n",
    "    \"CartPole-v1\",\n",
    "\n",
    "    # Policy architecture to use.\n",
    "    # This can also be given as a subclass of `torch.nn.Module`, or an instantiated\n",
    "    # `torch.nn.Module` object. For simplicity, we use a basic feed-forward neural\n",
    "    # network, that can be expressed as a string.\n",
    "    \"Linear(obs_length, 16) >> Tanh() >> Linear(16, act_length)\",\n",
    "\n",
    "    # Setting `observation_normalization` as True means that stats regarding\n",
    "    # observations will be collected during each population evaluation\n",
    "    # process, and those stats will be used to normalize the future\n",
    "    # observation data (that will be given as input to the policy).\n",
    "    observation_normalization=True,\n",
    "\n",
    "    # Number of actors to be used. Can be an integer.\n",
    "    # The string \"max\" means that the number of actors will be equal to the\n",
    "    # number of CPUs.\n",
    "    num_actors=\"max\",\n",
    ")\n",
    "\n",
    "problem"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4",
   "metadata": {},
   "source": [
    "Now that we have instantiated our problem, we make a callable evaluator object from it. This callable evaluator, named `f`, behaves like a function `f(x)`, where `x` can be a single solution (represented by a 1-dimensional tensor), or a population (represented by a 2-dimensional tensor where each row is a solution), or a batch of populations (represented by a tensor with at least 3 dimensions). Upon receiving its argument `x`, `f` uses the problem object to evaluate the solution(s), and return the evalution result(s)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "f = problem.make_callable_evaluator()\n",
    "f"
   ]
  },
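  {
   "cell_type": "markdown",
   "id": "5a",
   "metadata": {},
   "source": [
    "To make the shape convention described above concrete, here is a minimal sketch that uses a hypothetical stand-in function, `toy_evaluator`, instead of the actual `GymNE` evaluator, so that it runs instantly without any environment. The stand-in and all `toy_*` names are assumptions made for illustration only. It scores each solution by its negative squared norm and reduces over the last dimension, so a single solution yields a scalar, a population yields one fitness per row, and a batch of populations yields one row of fitnesses per population."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def toy_evaluator(values: torch.Tensor) -> torch.Tensor:\n",
    "    # Hypothetical stand-in for `f` (not the GymNE evaluator).\n",
    "    # Reducing over the last dimension keeps all leading (batch) dimensions.\n",
    "    return -torch.sum(values**2, dim=-1)\n",
    "\n",
    "\n",
    "toy_solution_length = 5  # arbitrary length, chosen only for this illustration\n",
    "\n",
    "toy_single = torch.zeros(toy_solution_length)  # a single solution\n",
    "toy_population = torch.randn(10, toy_solution_length)  # a population of 10 solutions\n",
    "toy_batch = torch.randn(3, 10, toy_solution_length)  # a batch of 3 populations\n",
    "\n",
    "print(toy_evaluator(toy_single).shape)  # torch.Size([])\n",
    "print(toy_evaluator(toy_population).shape)  # torch.Size([10])\n",
    "print(toy_evaluator(toy_batch).shape)  # torch.Size([3, 10])"
   ]
  },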
  {
   "cell_type": "markdown",
   "id": "6",
   "metadata": {},
   "source": [
    "Hyperparameters for `pgpe`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {},
   "outputs": [],
   "source": [
    "popsize = 100\n",
    "center_init = problem.make_zeros(num_solutions=1)[0]\n",
    "max_speed = 0.15\n",
    "center_learning_rate = max_speed * 0.75\n",
    "radius = max_speed * 15\n",
    "stdev_learning_rate = 0.1\n",
    "stdev_max_change = 0.2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8",
   "metadata": {},
   "source": [
    "We prepare the `pgpe` algorithm and get its initial state:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9",
   "metadata": {},
   "outputs": [],
   "source": [
    "pgpe_state = pgpe(\n",
    "    # Center of the initial search distribution\n",
    "    center_init=center_init,\n",
    "\n",
    "    # Radius for the initial search distribution\n",
    "    radius_init=radius,\n",
    "\n",
    "    # Learning rates for when updating the center and the standard deviation\n",
    "    # of the search distribution\n",
    "    center_learning_rate=center_learning_rate,\n",
    "    stdev_learning_rate=stdev_learning_rate,\n",
    "\n",
    "    # Maximum relative amount of change for standard deviation.\n",
    "    # Setting this as 0.2 means that an item of the standard deviation vector\n",
    "    # will not be allowed to change more than the 20% of its original value.\n",
    "    stdev_max_change=stdev_max_change,\n",
    "\n",
    "    # The ranking method to be used.\n",
    "    # \"centered\" is a ranking method which assigns the rank -0.5 to the worst\n",
    "    # solution, and +0.5 to the best solution.\n",
    "    ranking_method=\"centered\",\n",
    "\n",
    "    # The optimizer to be used. Can be \"clipup\", \"adam\", or \"sgd\".\n",
    "    optimizer=\"clipup\",\n",
    "\n",
    "    # Optimizer-specific hyperparameters:\n",
    "    optimizer_config={\"max_speed\": max_speed},\n",
    "\n",
    "    # Whether or not symmetric sampling will be used.\n",
    "    symmetric=True,\n",
    "\n",
    "    # We want to maximize the evaluation results.\n",
    "    # In the case of reinforcement learning tasks declared via `GymNE`,\n",
    "    # evaluation results represent the cumulative rewards.\n",
    "    objective_sense=\"max\",\n",
    ")"
   ]
  },
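  {
   "cell_type": "markdown",
   "id": "9a",
   "metadata": {},
   "source": [
    "The `\"centered\"` ranking mentioned in the comments above can be sketched in a few lines of plain PyTorch. The helper `centered_ranks_sketch` below is hypothetical and is not EvoTorch's own implementation. It assumes that higher fitness is better, that there are at least two solutions, and that the intermediate ranks are spaced linearly between -0.5 and +0.5. Because only the ordering of the fitnesses matters after ranking, the resulting utilities do not depend on the scale of the rewards."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def centered_ranks_sketch(fitness_values: torch.Tensor) -> torch.Tensor:\n",
    "    # Hypothetical helper for illustration; assumes at least two solutions.\n",
    "    # Applying argsort twice gives each solution's integer rank:\n",
    "    # 0 for the lowest fitness, n - 1 for the highest.\n",
    "    n = fitness_values.shape[-1]\n",
    "    ranks = torch.argsort(torch.argsort(fitness_values, dim=-1), dim=-1)\n",
    "    # Rescale the integer ranks to the interval [-0.5, +0.5].\n",
    "    return ranks.to(torch.float32) / (n - 1) - 0.5\n",
    "\n",
    "\n",
    "print(centered_ranks_sketch(torch.tensor([10.0, 30.0, 20.0])).tolist())\n",
    "# [-0.5, 0.5, 0.0]"
   ]
  },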
  {
   "cell_type": "markdown",
   "id": "10",
   "metadata": {},
   "source": [
    "Below is the main loop of the evolutionary search."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11",
   "metadata": {},
   "outputs": [],
   "source": [
    "# We run the evolutionary search for this many generations:\n",
    "num_generations = 40\n",
    "\n",
    "last_report_time = datetime.now()\n",
    "\n",
    "# This is the interval (in seconds) for reporting the status:\n",
    "reporting_interval = 1\n",
    "\n",
    "for generation in range(1, 1 + num_generations):\n",
    "    # Get a population from the pgpe algorithm\n",
    "    population = pgpe_ask(pgpe_state, popsize=popsize)\n",
    "\n",
    "    # Evaluate the fitnesses\n",
    "    fitnesses = f(population)\n",
    "\n",
    "    # Inform pgpe of the fitnesses and get its next state\n",
    "    pgpe_state = pgpe_tell(pgpe_state, population, fitnesses)\n",
    "\n",
    "    # If it is time to report, print the status\n",
    "    tnow = datetime.now()\n",
    "    if (tnow - last_report_time).total_seconds() > reporting_interval:\n",
    "        print(\"generation:\", generation, \"median eval:\", torch.median(fitnesses))\n",
    "        last_report_time = tnow"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12",
   "metadata": {},
   "source": [
    "Here is the center point of the most recent search distribution:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = pgpe_state.optimizer_state.center\n",
    "x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14",
   "metadata": {},
   "source": [
    "Now, we visualize the agent evolved by our functional `pgpe`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15",
   "metadata": {},
   "outputs": [],
   "source": [
    "problem.visualize(x)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
